import torch
import torch.nn.functional as F
import numpy as np

from utils import split_vt_feats, split_vt_feats_optimized, gen_vt_target, gen_vt_target_1
from utils import cal_vt_distribution, load_init_embs, get_text_feat, get_text_feat_1, get_results
from scipy.signal import find_peaks
from scipy.stats import norm


def get_binary_label(types, normal_str):
    """Build a 0/1 label vector from per-sample type entries.

    An entry is labelled 0 when its first element equals ``normal_str`` and
    1 otherwise. The result is a float tensor moved to the CUDA device.
    """
    flags = [0.0 if entry[0] == normal_str else 1.0 for entry in types]
    return torch.tensor(flags).cuda()


def mil_scale_loss(logits, types, normal_label, seq_len, bce_crit, betas):
    """Sum ``mil_loss`` over every entry of ``logits``, weighted by ``betas``.

    The first entry is scored with ``seq_len`` as given. Entry ``k`` (k >= 1)
    is scored with ``seq_len // (k + 1)``, where any resulting zero length is
    raised to 1.
    """
    binary_label = get_binary_label(types, normal_label)
    total = 0
    for level, level_logits in enumerate(logits):
        if level == 0:
            lengths = seq_len
        else:
            lengths = seq_len // (level + 1)
            # A zero length would leave nothing to pool over.
            lengths[lengths == 0] = 1
        total += betas[level] * mil_loss(level_logits, binary_label, lengths, bce_crit)
    return total


def mil_scale_loss_edl(logits, types, normal_label, seq_len, bce_crit, betas, sf, epoch, feat):
    """Return ``betas[0]`` times ``mil_loss_edl`` evaluated on ``logits[0]``.

    Only the first entry of ``logits`` and ``betas`` is used.
    """
    binary_label = get_binary_label(types, normal_label)
    weight = betas[0]
    first_scale_loss = mil_loss_edl(logits[0], binary_label, seq_len, bce_crit, sf, epoch, feat)
    return weight * first_scale_loss


def mil_scale_loss_single_sf(epoch, logits, edl_logits, types, normal_label, seq_len, bce_crit, betas, sf, ori_feat,
                             semantic_feat):
    """Return ``betas[0]`` times ``mil_loss_single_sf`` on the first scale.

    Only ``logits[0]``, ``edl_logits[0]`` and ``betas[0]`` are used.

    NOTE(review): ``mil_loss_single_sf`` is neither defined nor imported in
    the part of this file that is visible here -- confirm it exists.
    """
    binary_label = get_binary_label(types, normal_label)
    weight = betas[0]
    first_scale_loss = mil_loss_single_sf(
        epoch, logits[0], edl_logits[0], binary_label, seq_len, bce_crit, sf, ori_feat, semantic_feat)
    return weight * first_scale_loss


def abnor_prompt_loss(model, text_model, logits, v_feat, kl_crit, lamdas, types, label_list, normal_label):
    """Weighted sum of ``kldiv_loss`` between video-to-text logits and a target.

    For each entry of ``logits`` the matching ``v_feat`` entry is split with
    ``split_vt_feats`` and scored against the text features by
    ``cal_vt_distribution``. The target is built once by ``gen_vt_target``
    from the text features and labels returned for the LAST entry, and each
    per-entry loss is weighted by the matching ``lamdas`` value.
    """
    scale = model.logit_scale.exp()
    pos_text_feats, rel_text_feats, nor_text_feat = get_text_feat(
        text_model, types, label_list, model.lprompt_pre, model.lprompt_rear)

    per_scale_logits = []
    for level, level_logits in enumerate(logits):
        video_feats, text_feats, tfeat_labels = split_vt_feats(
            v_feat[level], pos_text_feats, rel_text_feats, nor_text_feat,
            level_logits, types, label_list, normal_label)
        per_scale_logits.append(cal_vt_distribution(video_feats, text_feats, scale))

    # text_feats / tfeat_labels deliberately carry over from the final loop pass.
    target = torch.tensor(
        gen_vt_target(tfeat_labels, text_feats, normal_label),
        dtype=v_feat[0].dtype,
    ).cuda()

    return sum(
        lamdas[level] * kldiv_loss(scale_logits, target, kl_crit)
        for level, scale_logits in enumerate(per_scale_logits)
    )


def abnor_prompt_loss_1(model, text_model, logits, v_feat, kl_crit, lamdas, types, label_list, normal_label, factor1,
                        factor2):
    """Variant of the prompt loss built on the ``*_1`` / optimized helpers.

    Text features come from ``get_text_feat_1``; each ``v_feat`` entry is
    reduced by ``split_vt_feats_optimized`` and scored against them with
    ``cal_vt_distribution``. The target comes from ``gen_vt_target_1`` (which
    receives ``factor1`` and ``factor2``), and the per-entry ``kldiv_loss``
    values are summed with weights ``lamdas``.
    """
    scale = model.logit_scale.exp()
    text_feats, nor_feat, type_flags = get_text_feat_1(
        text_model, types, label_list, model.lprompt_pre, model.lprompt_rear)

    per_scale_logits = [
        cal_vt_distribution(
            split_vt_feats_optimized(v_feat[level], level_logits, types, label_list, normal_label),
            text_feats,
            scale,
        )
        for level, level_logits in enumerate(logits)
    ]

    target = torch.tensor(
        gen_vt_target_1(types, type_flags, text_feats, nor_feat, normal_label,
                        factor1=factor1, factor2=factor2),
        dtype=v_feat[0].dtype,
    ).cuda()

    return sum(
        lamdas[level] * kldiv_loss(scale_logits, target, kl_crit)
        for level, scale_logits in enumerate(per_scale_logits)
    )


def prompt_constrain_loss(model, text_model, label_list, cos_crit, alphas):
    """Penalise learnable prompts for drifting from the initial embeddings.

    ``model.lprompt_pre`` and ``model.lprompt_rear`` are each flattened to
    rows of the embedding size and compared by ``cos_crit`` (with an all-ones
    target) against the embeddings from ``load_init_embs`` repeated along
    dim 1. The two terms are weighted by ``alphas[0]`` and ``alphas[1]``.
    """
    prompt_pre = model.lprompt_pre
    prompt_rear = model.lprompt_rear
    n_prompts = prompt_pre.shape[0]
    emb_dim = prompt_pre.shape[2]
    anchors = load_init_embs(text_model, label_list)

    def _constraint(prompt):
        # One cosine target of 1 per flattened prompt token.
        length = prompt.shape[1]
        return cos_crit(
            prompt.reshape(-1, emb_dim),
            anchors.repeat(1, length, 1).reshape(-1, emb_dim),
            torch.ones((n_prompts * length,), dtype=torch.float32).cuda(),
        )

    pre_term = _constraint(prompt_pre)
    rear_term = _constraint(prompt_rear)
    return pre_term * alphas[0] + rear_term * alphas[1]


def mil_loss(logits, label, seq_len, criterion):
    """Pool each row's top scores into one value and apply ``criterion``.

    Rows whose label is 0 contribute their single highest score within the
    first ``seq_len[i]`` entries; other rows contribute the mean of their top
    ``seq_len[i] // 16 + 1`` scores. Pooled values within 1e-7 of exactly 0
    or 1 are nudged inward by 1e-7 before ``criterion(pooled, label)``.

    NOTE(review): ``logits.squeeze()`` drops every size-1 dim -- confirm the
    batch dimension can never be 1.
    """
    scores = logits.squeeze()
    # Seed with the same empty CUDA tensor so dtype/device of the result match.
    pooled = [torch.zeros(0).cuda()]
    for row in range(scores.shape[0]):
        top_k = 1 if label[row] == 0 else int(seq_len[row] // 16 + 1)
        top_vals, _ = torch.topk(scores[row][:seq_len[row]], k=top_k, largest=True)
        pooled.append(torch.mean(top_vals).view(1))
    video_scores = torch.cat(pooled)

    near_one = torch.isclose(video_scores, torch.tensor(1.0), atol=1e-7)
    near_zero = torch.isclose(video_scores, torch.tensor(0.0), atol=1e-7)
    video_scores = torch.where(near_one, video_scores - 1e-7, video_scores)
    video_scores = torch.where(near_zero, video_scores + 1e-7, video_scores)

    return criterion(video_scores, label)


def generate_decreasing_sequence(n):
    """Return ``n`` random non-increasing weights that sum to 1.

    Draws ``n`` values with ``np.random.rand``, sorts them in descending
    order and divides by their sum.

    Raises:
        ValueError: if ``n`` is not greater than 0.
    """
    if not n > 0:
        raise ValueError("Length of the sequence must be greater than 0")
    descending = np.sort(np.random.rand(n))[::-1]
    return descending / descending.sum()


def mil_loss_edl(logits, label, seq_len, criterion, sf, epoch, feat):
    """Evidential MIL loss plus a regulariser on the non-selected frames.

    For each row, one frame is selected: for label 0 the frame with the
    highest ``prob[..., 0]`` (from ``get_results``) within ``seq_len[i]``,
    otherwise the annotated frame ``sf[i]``. The selected logits go through
    ``edl_digamma_loss``; every other frame inside ``seq_len[i]`` goes through
    ``regular_loss_fn``, and the summed regulariser is scaled by 1/500.

    ``criterion`` and ``feat`` are accepted for interface compatibility but
    are not used: the previous version computed a pairwise similarity of
    ``feat`` and copied the uncertainty to the CPU without reading either.
    """
    logits = logits.squeeze()
    ins_logits = torch.zeros(0).cuda()
    prob, _ = get_results(logits)
    label = label.tolist()
    regular_loss = 0
    for i in range(logits.shape[0]):
        if label[i] == 0:
            _, idx = torch.topk(prob[i][:seq_len[i], 0].squeeze(), k=1, largest=True)
        else:
            idx = torch.tensor([sf[i]])
        selected = torch.mean(logits[i][idx], dim=0, keepdim=True)
        ins_logits = torch.cat((ins_logits, selected))

        # Everything inside the valid length except the selected frame is
        # pushed through the regulariser.
        mask = torch.ones(logits[i].shape, dtype=torch.bool).cuda()
        mask[idx] = False
        mask[seq_len[i]:] = False
        regular_loss += torch.mean(regular_loss_fn(logits[i][mask], 2))

    # Nudge values sitting exactly on 0 or 1 inward by 1e-7.
    mask_ones = torch.isclose(ins_logits, torch.tensor(1.0), atol=1e-7)
    mask_zeros = torch.isclose(ins_logits, torch.tensor(0.0), atol=1e-7)
    ins_logits = torch.where(mask_ones, ins_logits - 1e-7, ins_logits)
    ins_logits = torch.where(mask_zeros, ins_logits + 1e-7, ins_logits)

    label = torch.tensor(label).cuda()
    # Two-column target: column 0 is the label, column 1 its complement.
    label = torch.stack([label, 1 - label], dim=1).float()
    clsloss = edl_digamma_loss(ins_logits, label, epoch, 2, 10)
    return clsloss + regular_loss / 5e2


def find_nearest_interval(lst, num):
    """Return ``[smaller, larger]``: the closest values on either side of ``num``.

    ``smaller`` is the element of ``lst`` strictly below ``num`` with the
    smallest gap to it, ``larger`` the element strictly above; either is
    ``None`` when no such element exists. Elements equal to ``num`` are
    ignored and, on ties, the first occurrence wins.
    """
    below = [value for value in lst if value < num]
    above = [value for value in lst if value > num]
    nearest_below = min(below, key=lambda value: num - value, default=None)
    nearest_above = min(above, key=lambda value: value - num, default=None)
    return [nearest_below, nearest_above]


def get_abn_interval(sf, sim, uncertainty):
    """Estimate ``[start, end]`` of an interval around frame ``sf``.

    ``start`` is the largest of: the index of the minimum of ``sim``, the
    most prominent peak of ``uncertainty`` (when any peak exists), and the
    nearest boundary above ``sf`` obtained by merging the surrounding
    uncertainty peaks with the surrounding minima of ``sim`` (0 when there is
    none). ``end`` is always ``sim.shape[0]``.

    Differences from the previous version: an unused variance computation
    was removed, and an ``uncertainty`` curve without any peak no longer
    raises from ``np.argmax`` on an empty array -- the prominence candidate
    is simply skipped.

    Assumes ``sim`` and ``uncertainty`` are 1-D NumPy arrays -- TODO confirm
    against callers.
    """
    peaks, properties = find_peaks(uncertainty, prominence=0)
    # Peaks of -sim are the local minima of sim.
    sim_valleys = find_peaks(-sim)[0]
    interval = merge_intervals_1(
        find_nearest_interval(peaks, sf),
        find_nearest_interval(sim_valleys, sf),
    )

    candidates = [np.argmin(sim)]
    if len(peaks) > 0:
        candidates.append(peaks[np.argmax(properties['prominences'])])
    candidates.append(interval[1] if interval[1] is not None else 0)

    start = np.max(np.array(candidates))
    return [start, sim.shape[0]]


def merge_intervals_1(interval1, interval2):
    """Intersect two ``[start, end]`` intervals whose bounds may be ``None``.

    The merged start is the larger of the two starts and the merged end the
    smaller of the two ends. A ``None`` bound defers to the other interval's
    bound (which may itself be ``None``).
    """
    (start_a, end_a), (start_b, end_b) = interval1, interval2

    def _combine(first, second, pick):
        # A missing bound defers to the other side.
        if first is None:
            return second
        if second is None:
            return first
        return pick(first, second)

    return [_combine(start_a, start_b, max), _combine(end_a, end_b, min)]


def edl_digamma_loss(
        output, target, epoch_num, num_classes, annealing_step, device=None
):
    """Mean ``edl_loss`` (digamma form) with re-weighting of column-0 positives.

    Evidence is ``leaky_relu(output, negative_slope=-0.01)`` and
    ``alpha = evidence + 1``. Samples whose ``target[:, 0]`` equals 1 are
    weighted by (count of other samples) / (sum of ``target[:, 0]``); all
    other samples get weight 1.
    """
    alpha = F.leaky_relu(output, negative_slope=-0.01) + 1  # F.relu(output)
    positives = target[:, 0]
    pos_weight = (positives.numel() - positives.sum()) / positives.sum()
    sample_weights = torch.where(positives == 1, pos_weight, torch.tensor(1))
    per_sample = edl_loss(
        torch.digamma, target, alpha, epoch_num, num_classes, annealing_step, device
    ).squeeze()
    return torch.mean(per_sample * sample_weights)


def mse_loss(y, alpha, epoch_num, num_classes, annealing_step, device=None):
    """Return ``loglikelihood_loss`` plus an annealed KL term.

    The KL weight is ``min(1, epoch_num / annealing_step)`` and the KL term is
    evaluated on ``(alpha - 1) * (1 - y) + 1``. Both inputs are moved to CUDA.
    """
    y, alpha = y.cuda(), alpha.cuda()
    fit_term = loglikelihood_loss(y, alpha, device=device)

    ceiling = torch.tensor(1.0, dtype=torch.float32)
    progress = torch.tensor(epoch_num / annealing_step, dtype=torch.float32)
    anneal = torch.min(ceiling, progress)

    masked_alpha = (alpha - 1) * (1 - y) + 1
    kl_term = anneal * kl_divergence(masked_alpha, num_classes, device=device)
    return fit_term + kl_term


def regular_loss_fn(output, num_classes):
    """Return ``kl_divergence`` of the Dirichlet parameters derived from ``output``.

    ``output`` is reshaped to rows of 2, turned into evidence with
    ``leaky_relu(negative_slope=-0.01)`` and shifted by 1. The target is all
    zeros, so the masking expression keeps every component.
    """
    pairs = output.reshape(-1, 2)
    alpha = F.leaky_relu(pairs, negative_slope=-0.01) + 1  # F.relu(output)
    zero_target = torch.zeros_like(alpha).cuda()
    masked_alpha = (alpha - 1) * (1 - zero_target) + 1
    return kl_divergence(masked_alpha, num_classes).cuda()


def loglikelihood_loss(y, alpha, device=None):
    """Squared-error plus variance term for Dirichlet parameters ``alpha``.

    With ``S = sum(alpha, dim=1)``, returns per row
    ``sum((y - alpha / S) ** 2) + sum(alpha * (S - alpha) / (S * S * (S + 1)))``
    with the reduced dim kept. Both inputs are moved to CUDA; ``device`` is
    not used.
    """
    y, alpha = y.cuda(), alpha.cuda()
    strength = alpha.sum(dim=1, keepdim=True)
    squared_err = ((y - (alpha / strength)) ** 2).sum(dim=1, keepdim=True)
    variance = (
        alpha * (strength - alpha) / (strength * strength * (strength + 1))
    ).sum(dim=1, keepdim=True)
    return squared_err + variance


def edl_mse_loss(output, target, epoch_num, num_classes, annealing_step, device=None):
    """Mean of ``mse_loss`` using ``relu(output) + 1`` as Dirichlet parameters."""
    alpha = F.relu(output) + 1
    per_sample = mse_loss(target, alpha, epoch_num, num_classes, annealing_step, device=device)
    return per_sample.mean()


def edl_loss(func, y, alpha, epoch_num, num_classes, annealing_step, device=None):
    """Per-sample evidential loss ``sum(y * (func(S) - func(alpha)))`` plus a KL term.

    ``func`` is applied element-wise to the Dirichlet strength ``S`` (row sum
    of ``alpha``) and to ``alpha`` itself; ``edl_digamma_loss`` passes
    ``torch.digamma``. Both ``y`` and ``alpha`` are moved to CUDA. ``device``
    is not used.

    NOTE(review): the annealing coefficient computed from ``epoch_num`` and
    ``annealing_step`` is immediately overwritten with 0, so the KL term is
    multiplied by zero and ``epoch_num`` / ``annealing_step`` have no effect
    on the value returned -- confirm this is intentional.
    """
    y = y.cuda()
    alpha = alpha.cuda()
    # Dirichlet strength per row, kept as a column for broadcasting.
    S = torch.sum(alpha, dim=1, keepdim=True)

    A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True)

    # min(1, epoch_num / annealing_step); discarded by the override below.
    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )
    # Hard override: the KL term below is always scaled by zero.
    annealing_coef = 0
    # Resets alpha to 1 wherever y is 1; keeps it where y is 0.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes).cuda()
    return A + kl_div


def kl_divergence(alpha, num_classes, device=None):
    """KL divergence of Dirichlet(``alpha``) from the all-ones Dirichlet.

    Computed per row (dim 1) with the reduced dim kept. The all-ones
    reference of width ``num_classes`` is created on CUDA; ``device`` is not
    used.
    """
    uniform = torch.ones([1, num_classes], dtype=torch.float32).cuda()
    strength = alpha.sum(dim=1, keepdim=True)

    log_normaliser = (
        torch.lgamma(strength)
        - torch.lgamma(alpha).sum(dim=1, keepdim=True)
        + torch.lgamma(uniform).sum(dim=1, keepdim=True)
        - torch.lgamma(uniform.sum(dim=1, keepdim=True))
    )
    digamma_gap = torch.digamma(alpha) - torch.digamma(strength)
    expectation = ((alpha - uniform) * digamma_gap).sum(dim=1, keepdim=True)
    return log_normaliser + expectation


def generate_indices(seq_len, sf, exclusion_radius=5):
    """List the indices in ``range(seq_len)`` farther than ``exclusion_radius`` from ``sf``.

    Indices ``sf - exclusion_radius`` through ``sf + exclusion_radius``
    (inclusive) are left out.
    """
    lower = sf - exclusion_radius
    upper = sf + exclusion_radius
    return [pos for pos in range(seq_len) if pos < lower or pos > upper]


def trimmed_mean(data, percentage=0.1):
    """Mean of ``data`` after dropping a fraction of the extremes on each side.

    The values are sorted and ``int(len(data) * percentage)`` entries are
    removed from both the low and the high end before averaging.
    """
    ordered, _ = torch.sort(data)
    total = len(ordered)
    cut = int(total * percentage)
    return ordered[cut:total - cut].mean()


def weight_func(idxs, uncertainty, **kwargs):
    """Normalised weights that favour the least uncertain selected entries.

    Each selected entry gets ``max - value + 1e-3`` (max taken over the
    selection), and the result is divided by its sum and squeezed. Extra
    keyword arguments are ignored.
    """
    picked = uncertainty[idxs]
    inverted = picked.max() - picked + 1e-3
    return (inverted / inverted.sum()).squeeze()


class SF_MIL_Loss:
    """Weighted BCE loss over frame groups sampled from each video.

    ``cal_loss`` walks the batch; for every video it runs one or more sampling
    generators (``*_sample`` methods). Each yielded index set is reduced to a
    single score, stored with a 0/1 label and a per-sample weight, and the
    collected scores are finally passed to ``torch.nn.BCELoss(weight=...)``.

    ``ws`` is read at indices 0, 1, 2, 4 and 5 (see ``generate_weights``);
    index 3 is not read by this class.

    Assumes ``seq_len`` and ``sf`` are tensors (``.max()``, ``.cuda()`` and
    ``.min()`` are called on them or their elements) -- TODO confirm.
    """

    def __init__(self, epoch, logits, edl_logits, types, normal_label, seq_len, sf, ori_feat, ws):
        self.epoch = epoch
        self.logits = logits.squeeze()
        self.ori_label = get_binary_label(types, normal_label)
        # Accumulated per-sample BCE targets and weights, filled by instance_loss.
        self.labels = []
        self.weights = []
        self.seq_len = seq_len
        self.frame_anno = sf
        # Index of the video currently being processed by cal_loss.
        self.idx = None
        # Running [left, right) extent of frames sampled as abnormal for the
        # current video; starts "empty" (left at max length, right at 0).
        self.abn_left_idx, self.abn_right_idx = seq_len.max().item(), 0
        self.mil_logits = torch.zeros(0).cuda()
        self.frame_sim = pairwise_cosine_similarity(ori_feat)
        self.edl_prob, self.uncertainty = get_results(edl_logits)
        self.ws = ws

    def cal_loss(self):
        """Collect all samples for the batch and return the weighted BCE loss.

        Note that samples are appended to instance state, so calling this
        twice on the same object accumulates the samples twice.
        """
        for i in range(self.logits.shape[0]):
            self.idx = i
            if self.ori_label[i] == 0:
                # For normal videos
                self.instance_loss(
                    ref_logits=self.logits[i],
                    sample_method=self.topk_sample,
                    is_abnormal=False,
                    k=1, largest=True)
            else:

                self.reset_abnormal_idx()
                if self.epoch > -1:
                    k = self.seq_len[i] // 16 + 1
                    self.instance_loss(
                        ref_logits=self.logits[i],
                        sample_method=self.topk_sample,
                        is_abnormal=True,
                        k=k, largest=True)

                # The annotated frame itself.
                self.instance_loss(
                    ref_logits=None,
                    sample_method=self.single_frame_sample,
                    is_abnormal=True,
                    frame=self.frame_anno[i])

                # The interval mined around the annotated frame.
                self.instance_loss(
                    ref_logits=self.uncertainty[i],
                    sample_method=self.abnormal_interval_sample,
                    is_abnormal=True,
                    weight_method=weight_func,
                    weight_factor=self.uncertainty[i],
                    frame=self.frame_anno[i],
                    reduce_func=torch.sum,
                    similarity=self.frame_sim[i])

                # Distant frames highly similar to the annotated frame.
                self.instance_loss(
                    ref_logits=self.frame_sim[i],
                    sample_method=self.abnormal_multi_interval_sample,
                    is_abnormal=True,
                    weight_method=weight_func,
                    weight_factor=self.uncertainty[i],
                    frame=self.frame_anno[i],
                    reduce_func=torch.sum,
                    similarity=self.frame_sim[i],
                    uncertainty=self.uncertainty[i])

                # Normal samples from inside the abnormal video; must run last
                # because it reads the abnormal extent recorded above.
                self.instance_loss(
                    ref_logits=self.uncertainty[i],
                    sample_method=self.normal_interval_sample,
                    is_abnormal=False,
                    frame=self.frame_anno[i],
                    similarity=self.frame_sim[i])

        self.mil_logits = clip_to_range(self.mil_logits)
        labels = torch.tensor(self.labels, dtype=torch.float).cuda()

        # Force a float dtype: a weight list made only of Python ints would
        # otherwise become an integer tensor, which BCELoss cannot use.
        weights = torch.tensor(self.weights, dtype=torch.float).cuda()
        crit = torch.nn.BCELoss(weight=weights)
        clsloss = torch.mean(crit(self.mil_logits, labels))
        return clsloss

    def instance_loss(self,
                      ref_logits,
                      sample_method,
                      is_abnormal: bool,
                      weight_method=None,
                      weight_factor=None,
                      reduce_func=torch.mean,
                      **kwargs):
        """Run ``sample_method`` and record one score/label/weight per yield.

        ``sample_method`` is called as
        ``sample_method(ref_logits, seq_len_of_current_video, **kwargs)`` and
        may yield ``None`` (skipped), an index, or an ``(index, flag)`` tuple.
        When ``weight_method`` is given and more than one logit was sampled,
        the logits are scaled by ``weight_method(index, weight_factor)``
        before ``reduce_func`` collapses them to one value.
        """
        for data in sample_method(ref_logits, self.seq_len[self.idx], **kwargs):
            if data is None:
                continue
            if isinstance(data, tuple):
                sampled_idx, flag = data[0], data[1]
            else:
                sampled_idx, flag = data, None
            sampled_logit = self.logits[self.idx][sampled_idx]
            if is_abnormal:
                self.update_abnormal_idx(sampled_idx)
            if sampled_logit.numel() > 1 and weight_method is not None:
                factors = weight_method(sampled_idx, weight_factor)
                # Out-of-place so the stored logits can never be modified,
                # whatever kind of index produced sampled_logit.
                sampled_logit = sampled_logit * factors
            self.update_logits(sampled_logit, reduce_func)
            self.labels.append(1 if is_abnormal else 0)
            self.weights.append(self.generate_weights(sample_method, is_abnormal, flag))

    def update_logits(self, logits, reduce_func):
        """Reduce ``logits`` to a single value and append it to ``mil_logits``."""
        mean_logits = reduce_func(logits).view(1)
        self.mil_logits = torch.cat((self.mil_logits, mean_logits))

    def reset_abnormal_idx(self):
        """Reset the abnormal extent to its empty state."""
        self.abn_left_idx, self.abn_right_idx = self.seq_len.max().item(), 0

    def update_abnormal_idx(self, idxs):
        """Widen the abnormal extent so that it covers ``idxs``."""
        lowest = idxs.min().item()
        highest = idxs.max().item()
        if lowest < self.abn_left_idx:
            self.abn_left_idx = lowest
        if highest > self.abn_right_idx:
            # Right bound is exclusive.
            self.abn_right_idx = highest + 1

    def generate_weights(self, sample_method, is_abnormal, flag):
        """Return the BCE weight for a sample produced by ``sample_method``.

        Falls back to the neutral weight 1 when no rule matches (for example
        a sample yielded without a flag). Previously that case returned
        ``None``, which made building the weight tensor in ``cal_loss`` fail.
        """
        if sample_method == self.topk_sample:
            return self.ws[0] if is_abnormal else 1
        if sample_method == self.single_frame_sample:
            return 1
        if sample_method == self.abnormal_interval_sample:
            return self.ws[1]
        if sample_method == self.abnormal_multi_interval_sample:
            if flag == 0:
                return self.ws[2]
            if flag == 1:
                return self.ws[1]
        elif sample_method == self.normal_interval_sample:
            if flag == 0:
                return self.ws[4]
            if flag == 1:
                return self.ws[5]
        return 1

    def normal_interval_sample(self, uncertainty, seq_len, frame, similarity, **kwargs):
        """Yield frames of the current video to be treated as normal.

        Yields, in order: a random frame before the abnormal extent (flag 0),
        the lowest-scoring frame once ``epoch > 2`` (flag 1), and -- only for
        ``epoch > 2222`` -- the lowest-scoring frame after the last
        uncertainty peak / abnormal extent (no flag). Always ends with
        ``None``.
        """
        if self.abn_left_idx > 1:
            yield (torch.tensor([np.random.randint(0, self.abn_left_idx)]).cuda(), 0)
        if self.epoch > 2:
            yield (torch.topk(self.logits[self.idx][:seq_len], k=1, largest=False)[1], 1)
        if self.epoch > 2222 and frame < seq_len * 0.6:
            uc_numpy = uncertainty.detach().cpu().squeeze().numpy()
            peaks, _ = find_peaks(uc_numpy[:self.seq_len[self.idx]])
            if len(peaks) > 0:
                tail_start = max(peaks[-1], self.abn_right_idx)
                if tail_start < seq_len:
                    yield torch.topk(self.logits[self.idx][tail_start:seq_len], k=1, largest=False)[1]
        yield None

    def topk_sample(self, logits, seq_len, k, largest, **kwargs):
        """Yield the indices of the ``k`` extreme values within ``seq_len``."""
        yield torch.topk(logits[:seq_len], k=k, largest=largest)[1]

    def single_frame_sample(self, logits, seq_len, frame, **kwargs):
        """Yield the annotated frame unchanged."""
        yield frame

    def abnormal_interval_sample(self, uncertainty, seq_len, frame, similarity, **kwargs):
        """Yield the index range mined around ``frame`` if it spans more than one frame."""
        abn_left_b, abn_right_b = mine_abnormal_snippets(
            frame, seq_len, uncertainty, similarity)
        if abn_right_b - abn_left_b > 1:
            yield torch.arange(abn_left_b, abn_right_b)

    def abnormal_multi_interval_sample(self, logits, seq_len, similarity, uncertainty, frame):
        """Yield distant frames that look like the annotated frame.

        Active only when between 2 and 4 frames have similarity above 0.97 to
        ``frame`` and the trimmed mean similarity is below 0.95. Each such
        frame farther than 20% of ``seq_len`` from ``frame`` is yielded with
        flag 0, followed by its mined interval (flag 1) when one exists.
        """
        high_sim_indices = torch.nonzero(similarity[frame] > 0.97).view(-1)
        mean_value = trimmed_mean(similarity[frame])
        if 1 < high_sim_indices.shape[0] < 5 and mean_value < 0.95:
            for s_i in high_sim_indices:
                if s_i != frame.cuda() and torch.abs(s_i - frame.cuda()) > seq_len.cuda() * 0.2:
                    yield (s_i, 0)
                    try:
                        yield (next(self.abnormal_interval_sample(uncertainty, seq_len, s_i, similarity)), 1)
                    except StopIteration:
                        # No usable interval around this frame.
                        pass


def generate_right_tail_probs(n):
    """Return ``1 - pdf(x)`` of the standard normal at ``n`` points in [0.5, 2].

    The points are evenly spaced (``np.linspace``), so the returned values
    increase from left to right.
    """
    return 1 - norm.pdf(np.linspace(0.5, 2, num=n), loc=0, scale=1)


def softmax(np_array):
    """Numerically stable softmax along the first axis.

    The maximum along axis 0 is subtracted before exponentiating, so large
    inputs no longer overflow to ``inf`` and produce NaN. Normalisation is
    along axis 0, matching the previous use of the builtin ``sum``. An empty
    input yields an empty array.
    """
    values = np.asarray(np_array)
    if values.size == 0:
        return np.exp(values)
    # Shifting by the max leaves the result unchanged mathematically.
    exps = np.exp(values - values.max(axis=0))
    return exps / exps.sum(axis=0)


def randint_with_prob(low, high, probs):
    """Draw one integer from ``range(low, high)`` with relative weights ``probs``.

    ``probs`` is normalised on a float copy, so the caller's array is left
    untouched and lists or integer arrays are accepted. Returns a length-1
    array, as produced by ``Generator.choice(..., size=1)``.
    """
    values = list(range(low, high))
    weights = np.asarray(probs, dtype=float)
    weights = weights / weights.sum()
    return np.random.default_rng().choice(values, size=1, p=weights)


def clip_to_range(tensor, low=0, high=1, bias=1e-5):
    """Nudge values that sit on ``low`` or ``high`` inward by ``bias``.

    Entries that ``torch.isclose`` reports as equal to ``high`` (with
    ``atol=bias``) are lowered by ``bias``; entries close to ``low`` are
    raised by ``bias``. Both masks are computed on the input before any
    adjustment, and all other entries are returned unchanged.
    """
    upper = torch.tensor(float(high))
    lower = torch.tensor(float(low))
    at_upper = torch.isclose(tensor, upper, atol=bias)
    at_lower = torch.isclose(tensor, lower, atol=bias)
    nudged = torch.where(at_upper, tensor - bias, tensor)
    return torch.where(at_lower, nudged + bias, nudged)


def mine_abnormal_snippets(sf, seq_len, uncertainty, similarity):
    """Find the uncertainty peaks that bracket frame ``sf``.

    Only the first ``seq_len`` entries of ``uncertainty`` are considered. If
    the uncertainty at ``sf`` is at or below the 5% quantile (nearest
    interpolation), the left boundary becomes the last peak before ``sf`` and
    the right boundary the first peak after it; a side without a peak keeps
    ``sf``. Otherwise ``(sf, sf)`` is returned.

    ``sf == seq_len`` is treated as the last valid frame. The adjustment
    rebinds the local name instead of using ``-=``, so a tensor passed as
    ``sf`` is no longer modified in the caller. ``similarity`` is accepted
    for interface compatibility but is not used.
    """
    if sf == seq_len:
        sf = sf - 1
    scores = uncertainty[:seq_len].squeeze()
    threshold = torch.quantile(scores, 0.05, interpolation='nearest')
    left_boundary, right_boundary = sf, sf
    if scores[sf] <= threshold:
        peaks, _ = find_peaks(scores.detach().cpu().numpy())
        for peak in peaks:
            if peak < sf:
                # Keep overwriting: the last peak before sf wins.
                left_boundary = peak
            elif peak > sf:
                right_boundary = peak
                break
    return left_boundary, right_boundary


def mine_normal_snippets():
    """Placeholder with no implementation; takes no arguments and returns ``None``."""
    return None


def gaussian_pdf(xs, mu=0.5, sigma=0.5):
    """Evaluate the normal density N(``mu``, ``sigma``^2) at each value of ``xs``.

    Returns a plain list with one density per input value.
    """
    scale = 1 / (np.sqrt(2 * np.pi) * sigma)
    two_variance = 2 * sigma ** 2
    return [scale * np.exp(-((x - mu) ** 2) / two_variance) for x in xs]


def kldiv_loss(preds, label, criterion):
    """Apply ``criterion`` to log-probabilities of ``preds`` and a sharpened target.

    ``preds`` is soft-maxed along dim 1 and logged (with 1e-8 added first);
    the target is ``softmax(label * 10, dim=1)``. If the log-probabilities
    contain any NaN the plain integer 0 is returned instead of calling
    ``criterion``.
    """
    log_probs = torch.log(F.softmax(preds, dim=1) + 1e-8)
    if torch.isnan(log_probs).any():
        return 0
    return criterion(log_probs, F.softmax(label * 10, dim=1))


def temporal_smooth(arr):
    """Sum of squared differences between consecutive entries along dim 0.

    Inputs with zero or one entry along dim 0 give 0. The previous version
    raised ``IndexError`` on an empty input, which ``smooth`` can produce
    when a sequence length is 1.
    """
    steps = arr[1:] - arr[:-1]
    return torch.sum(steps ** 2)


def temporal_sparsity(arr):
    """Return the sum of all entries of ``arr``."""
    return arr.sum()


def smooth(logits, seq_len, lamda=8e-5):
    """Batch mean of ``temporal_smooth``, scaled by ``lamda``.

    For row ``i`` only the first ``seq_len[i] - 1`` entries are used.
    """
    per_video = [
        temporal_smooth(logits[row][:seq_len[row] - 1])
        for row in range(logits.shape[0])
    ]
    batch_mean = sum(per_video) / len(per_video)
    return batch_mean * lamda


def sparsity(logits, seq_len, lamda=8e-5):
    """Batch mean of ``temporal_sparsity``, scaled by ``lamda``.

    For row ``i`` only the first ``seq_len[i]`` entries are used.
    """
    per_video = [
        temporal_sparsity(logits[row][:seq_len[row]])
        for row in range(logits.shape[0])
    ]
    batch_mean = sum(per_video) / len(per_video)
    return batch_mean * lamda


def smooth_sparsity(logits, seq_len, lamda=8e-5):
    """Sum of the batch-mean smoothness and sparsity terms, scaled by ``lamda``.

    Both terms are computed on the first ``seq_len[i]`` entries of each row,
    using ``temporal_smooth`` and ``temporal_sparsity`` respectively.
    """
    smooth_terms = []
    sparse_terms = []
    for row in range(logits.shape[0]):
        valid = logits[row][:seq_len[row]]
        smooth_terms.append(temporal_smooth(valid))
        sparse_terms.append(temporal_sparsity(valid))
    mean_smooth = sum(smooth_terms) / len(smooth_terms)
    mean_sparse = sum(sparse_terms) / len(sparse_terms)
    return (mean_smooth + mean_sparse) * lamda


def pairwise_cosine_similarity(tensor):
    """Cosine similarity between every pair of rows within each batch item.

    Rows are L2-normalised along the last dim and multiplied with their own
    transpose over dims 1 and 2, so the input needs at least three dims.
    """
    unit_rows = F.normalize(tensor, p=2, dim=-1)
    return unit_rows @ unit_rows.transpose(1, 2)
